{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Federated Learning Training Plan: Create Plan\n",
    "\n",
    "Let's build a protobuf-serializable Training Plan and Model that still work after deserialization :)\n",
    "\n",
    "Current list of problems:\n",
    " * Plan tracing does not support autograd (`.backward()` doesn't work inside a Plan).\n",
    " * `tensor.shape` values seem to be recorded as constants during Plan tracing, so `batch_size` has to be passed as an argument; it can't be taken from the tensor itself.\n",
    " * A Plan needs every Model param in its argument list. It would be nicer if this list were figured out automatically, so that you could just pass the Model (not sure this is solvable: jit might not accept the model as a ScriptModule input?).\n",
    " * A Plan doesn't return its input args, e.g. when they were updated by an in-place operation.\n",
    " * others?\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "is_executing": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "text": [
      "Setting up Sandbox...\n",
      "Done!\n"
     ],
     "output_type": "stream"
    },
    {
     "data": {
      "text/plain": "<torch._C.Generator at 0x22f01652890>"
     },
     "metadata": {},
     "output_type": "execute_result",
     "execution_count": 1
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import syft as sy\n",
    "import torch as th\n",
    "from torch import jit\n",
    "from syft.serde import protobuf\n",
    "import os\n",
    "from syft.execution.state import State\n",
    "from syft.execution.placeholder import PlaceHolder\n",
    "\n",
    "\n",
    "sy.hook(globals())\n",
    "# force protobuf serialization for tensors\n",
    "hook.local_worker.framework = None\n",
    "th.random.manual_seed(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "This utility function serializes an object to protobuf binary format and saves it to a file."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "def serializeToBinPb(worker, obj, filename):\n",
    "    # convert the object to its protobuf message, then to raw bytes\n",
    "    pb = protobuf.serde._bufferize(worker, obj)\n",
    "    data = pb.SerializeToString()  # named `data` to avoid shadowing the built-in `bin`\n",
    "    print(\"Writing %s to %s/%s\" % (obj.__class__.__name__, os.getcwd(), filename))\n",
    "    with open(filename, \"wb\") as f:\n",
    "        f.write(data)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": false
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Step 1: Define the model\n",
    "\n",
    "This model will be trained on MNIST data. It is very simple, yet enough to demonstrate the learning process.\n",
    "It has 2 linear layers with a ReLU in between:\n",
    "\n",
    "* Linear 784x392\n",
    "* ReLU\n",
    "* Linear 392x10\n",
    "\n",
    "We don't use `nn.Module` or `nn.Linear` for now, just a plain class and tensors.\n",
    "There is no autograd; the gradients are hand-coded.\n",
    "\n",
    "Since loops are not supported inside a Plan,\n",
    "we can't iterate over parameters there, so everything that works with params\n",
    "(get/set, step) is moved into the model to keep the Plan code more generic."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "class Net():\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.W1 = th.randn(392, 784) / th.sqrt(th.tensor(784.))\n",
    "        self.b1 = th.zeros(392)\n",
    "        self.W2 = th.randn(10, 392) / th.sqrt(th.tensor(392.))\n",
    "        self.b2 = th.zeros(10)\n",
    "        self.update_fn = None\n",
    "\n",
    "    def forward(self, X):\n",
    "        self.Z1 = X @ self.W1.t() + self.b1\n",
    "        self.A1 = th.nn.functional.relu(self.Z1)\n",
    "        self.Z2 = self.A1 @ self.W2.t() + self.b2\n",
    "        return self.Z2\n",
    "\n",
    "    def get_params(self):\n",
    "        return self.W1, self.b1, self.W2, self.b2\n",
    "\n",
    "    def set_params(self, *model_params):\n",
    "        self.W1, self.b1, self.W2, self.b2 = model_params\n",
    "\n",
    "    def grad(self, X, error):\n",
    "        Z1_grad = (error @ self.W2) * (self.Z1 > 0).float()\n",
    "        W1_grad = Z1_grad.t() @ X\n",
    "        b1_grad = Z1_grad.sum(0)\n",
    "        W2_grad = error.t() @ self.A1 \n",
    "        b2_grad = error.sum(0)\n",
    "        return W1_grad, b1_grad, W2_grad, b2_grad\n",
    "\n",
    "model = Net()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": false
    }
   }
  },
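  {
   "cell_type": "markdown",
   "source": [
    "A quick sanity check for the hand-coded gradients. The sketch below is not part of the Plan: it rebuilds the same two-layer forward pass with plain PyTorch tensors outside of `Net` and compares the formulas used in `Net.grad` against autograd. The batch of 4 samples, the seed, the tolerance and the random `error` tensor (standing in for the gradient of the loss with respect to the model output) are arbitrary choices made for illustration.\n",
    "\n",
    "```python\n",
    "import torch as th\n",
    "\n",
    "th.manual_seed(0)\n",
    "X = th.randn(4, 784)\n",
    "\n",
    "# same shapes and init as Net, but tracked by autograd\n",
    "W1 = (th.randn(392, 784) / 784 ** 0.5).requires_grad_()\n",
    "b1 = th.zeros(392, requires_grad=True)\n",
    "W2 = (th.randn(10, 392) / 392 ** 0.5).requires_grad_()\n",
    "b2 = th.zeros(10, requires_grad=True)\n",
    "\n",
    "# forward pass, same layout as Net.forward\n",
    "Z1 = X @ W1.t() + b1\n",
    "A1 = th.relu(Z1)\n",
    "Z2 = A1 @ W2.t() + b2\n",
    "\n",
    "# arbitrary upstream gradient dL/dZ2, pushed through autograd\n",
    "error = th.randn(4, 10)\n",
    "Z2.backward(error)\n",
    "\n",
    "# hand-coded gradients, same formulas as Net.grad\n",
    "with th.no_grad():\n",
    "    Z1_grad = (error @ W2) * (Z1 > 0).float()\n",
    "    hand = (Z1_grad.t() @ X, Z1_grad.sum(0), error.t() @ A1, error.sum(0))\n",
    "\n",
    "# one boolean per parameter: W1, b1, W2, b2\n",
    "matches = [th.allclose(h, p.grad, atol=1e-4) for h, p in zip(hand, (W1, b1, W2, b2))]\n",
    "print(matches)\n",
    "```\n",
    "\n",
    "If any entry is `False`, the corresponding formula in `Net.grad` disagrees with autograd for that parameter."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },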
  {
   "cell_type": "markdown",
   "source": [
    "## Step 2: Define Training Plan\n",
    "### Loss function\n",
    "The batch size needs to be passed explicitly, because otherwise `target.shape[0]` would be recorded as the constant `1` during the Plan trace with dummy data.\n",
    "The gradient of the loss with respect to the model output is also returned here."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
    "def cross_entropy_with_logits(output, target, batch_size):\n",
    "    probs = th.nn.functional.softmax(output, dim=1)\n",
    "    loss = -(target * th.log(probs)).mean()\n",
    "    loss_grad = (probs - target) / (batch_size * target.shape[1])\n",
    "    return probs, loss, loss_grad"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": false
    }
   }
  },
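  {
   "cell_type": "markdown",
   "source": [
    "Where `loss_grad` comes from: the derivation below is a sketch that assumes every row of `target` sums to 1 (e.g. one-hot labels), which the code above does not check. Here $B$ stands for `batch_size`, $C$ for `target.shape[1]`, $z$ for `output`, $y$ for `target` and $p = \\mathrm{softmax}(z)$ for `probs`. Since `.mean()` averages over all $B \\cdot C$ elements:\n",
    "\n",
    "$$L = -\\frac{1}{BC} \\sum_{i=1}^{B} \\sum_{c=1}^{C} y_{ic} \\log p_{ic}$$\n",
    "\n",
    "Using $\\partial \\log p_{ic} / \\partial z_{ik} = \\delta_{ck} - p_{ik}$:\n",
    "\n",
    "$$\\frac{\\partial L}{\\partial z_{ik}} = -\\frac{1}{BC} \\sum_{c=1}^{C} y_{ic} \\left( \\delta_{ck} - p_{ik} \\right) = \\frac{1}{BC} \\left( p_{ik} \\sum_{c=1}^{C} y_{ic} - y_{ik} \\right) = \\frac{p_{ik} - y_{ik}}{BC}$$\n",
    "\n",
    "The last step uses $\\sum_c y_{ic} = 1$, and the result is the expression `(probs - target) / (batch_size * target.shape[1])`."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },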
  {
   "cell_type": "markdown",
   "source": [
    "### Optimization function\n",
    "\n",
    "Plain SGD: it subtracts `grad * lr` from the parameter and returns the result as a new tensor.\n",
    "\n",
    "Note: a Plan can't update a parameter in place and return that value, so every step allocates new tensors, which is potentially bad for memory usage."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [],
   "source": [
    "def naive_sgd(param, grad, **kwargs):\n",
    "    return param - grad * kwargs['lr']"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": false
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Training Plan procedure"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [],
   "source": [
    "model_params = model.get_params()\n",
    "\n",
    "# define plan input dimensions\n",
    "X_size = (-1, 784)\n",
    "y_size = (-1, 10)\n",
    "scalar_size = (1,)\n",
    "model_params_shapes = [p.shape for p in model_params]\n",
    "\n",
    "args_shape = [\n",
    "    X_size,  # X\n",
    "    y_size,  # y\n",
    "    scalar_size,  # batch_size\n",
    "    scalar_size,  # lr\n",
    "    *model_params_shapes  # *model_params\n",
    "]\n",
    "\n",
    "@sy.func2plan(args_shape=args_shape)\n",
    "def training_plan(X, y, batch_size, lr, *model_params):\n",
    "    # inject params into model\n",
    "    model.set_params(*model_params)\n",
    "\n",
    "    # forward pass\n",
    "    output = model.forward(X)\n",
    "    \n",
    "    # loss\n",
    "    probs, loss, loss_grad = cross_entropy_with_logits(output, y, batch_size)\n",
    "\n",
    "    # backprop\n",
    "    grads = model.grad(X, loss_grad)\n",
    "\n",
    "    # step\n",
    "    updated_params = [naive_sgd(param, grads[i], lr=lr) \n",
    "                      for i, param in enumerate(model_params)]\n",
    "    \n",
    "    # accuracy\n",
    "    pred = th.argmax(probs, dim=1)\n",
    "    target = th.argmax(y, dim=1)\n",
    "    acc = pred.eq(target).float().sum() / batch_size\n",
    "\n",
    "    return (\n",
    "        loss,\n",
    "        acc,\n",
    "        *updated_params\n",
    "    )"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": false
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Let's look inside the Plan and print out the list of operations recorded."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [
    {
     "name": "stdout",
     "text": [
      "Inputs:\n",
      "X\n",
      "y\n",
      "batch_size\n",
      "lr\n",
      "W1\n",
      "b1\n",
      "W2\n",
      "b2\n",
      "\n",
      "Actions:\n",
      "var_2 = W1.t()\n",
      "var_4 = X.__matmul__(var_2)\n",
      "var_6 = var_4.__add__(b1)\n",
      "var_7 = torch.nn.functional.relu(var_6)\n",
      "var_9 = W2.t()\n",
      "var_10 = var_7.__matmul__(var_9)\n",
      "var_12 = var_10.__add__(b2)\n",
      "var_13 = torch.nn.functional.softmax(var_12, {'dim': 1})\n",
      "var_14 = torch.log(var_13)\n",
      "var_16 = y.__mul__(var_14)\n",
      "var_17 = var_16.mean()\n",
      "loss = var_17.__neg__()\n",
      "var_19 = var_13.__sub__(y)\n",
      "var_21 = batch_size.__mul__(10)\n",
      "var_22 = var_19.__truediv__(var_21)\n",
      "var_23 = var_22.__matmul__(W2)\n",
      "var_24 = var_6.__gt__(0)\n",
      "var_25 = var_24.float()\n",
      "var_26 = var_23.__mul__(var_25)\n",
      "var_27 = var_26.t()\n",
      "var_28 = var_27.__matmul__(X)\n",
      "var_29 = var_26.sum(0)\n",
      "var_30 = var_22.t()\n",
      "var_31 = var_30.__matmul__(var_7)\n",
      "var_32 = var_22.sum(0)\n",
      "var_34 = var_28.__mul__(lr)\n",
      "upd_W1 = W1.__sub__(var_34)\n",
      "var_36 = var_29.__mul__(lr)\n",
      "upd_b1 = b1.__sub__(var_36)\n",
      "var_38 = var_31.__mul__(lr)\n",
      "upd_W2 = W2.__sub__(var_38)\n",
      "var_40 = var_32.__mul__(lr)\n",
      "upd_b2 = b2.__sub__(var_40)\n",
      "var_42 = torch.argmax(var_13, {'dim': 1})\n",
      "var_43 = torch.argmax(y, {'dim': 1})\n",
      "var_44 = var_42.eq(var_43)\n",
      "var_45 = var_44.float()\n",
      "var_46 = var_45.sum()\n",
      "acc = var_46.__truediv__(batch_size)\n",
      "\n",
      "Outputs:\n",
      "loss\n",
      "acc\n",
      "upd_W1\n",
      "upd_b1\n",
      "upd_W2\n",
      "upd_b2\n"
     ],
     "output_type": "stream"
    }
   ],
   "source": [
    "import re\n",
    "\n",
    "input_names = [\n",
    "    \"X\", \"y\", \"batch_size\", \"lr\",\n",
    "    \"W1\", \"b1\", \"W2\", \"b2\"\n",
    "]\n",
    "\n",
    "output_names = [\n",
    "    \"loss\", \"acc\",\n",
    "    \"upd_W1\", \"upd_b1\", \"upd_W2\", \"upd_b2\"\n",
    "]\n",
    "\n",
    "def placeholderToStr(ph: PlaceHolder):\n",
    "    ret = ''\n",
    "    for tag in ph.tags:\n",
    "        if re.search('input', tag):\n",
    "            idx = int(tag.split(\"-\")[1])\n",
    "            return input_names[idx]\n",
    "        elif re.search('output', tag):\n",
    "            idx = int(tag.split(\"-\")[1])\n",
    "            return output_names[idx]\n",
    "        else:\n",
    "            ret = \"var_\" + tag.split(\"#\")[1]\n",
    "    return ret\n",
    "            \n",
    "def argToStr(arg):\n",
    "    if isinstance(arg, tuple):\n",
    "        return \", \".join(list(map(argToStr, arg)))\n",
    "    if isinstance(arg, PlaceHolder):\n",
    "        return placeholderToStr(arg)\n",
    "    else:\n",
    "        return str(arg)\n",
    "\n",
    "def tag_sort(keyword):\n",
    "    def extract_key(placeholder):\n",
    "        for tag in placeholder.tags:\n",
    "            if keyword in tag:\n",
    "                return int(tag.split(\"-\")[-1])\n",
    "    return extract_key\n",
    "\n",
    "print(\"Inputs:\")\n",
    "for inp in sorted(training_plan.find_placeholders(\"input\"), key=tag_sort(\"input\")):\n",
    "    print(argToStr(inp))\n",
    "\n",
    "print(\"\\nActions:\")\n",
    "for action in training_plan.actions:\n",
    "    expr = [argToStr(action.return_ids), ' = ']\n",
    "    if action.target is None:\n",
    "        expr += [action.name, '(']\n",
    "    else:\n",
    "        expr += [argToStr(action.target), '.', action.name, '(']\n",
    "\n",
    "    if len(action.args):\n",
    "        expr += argToStr(action.args)\n",
    "\n",
    "    if action.kwargs:\n",
    "        expr += ', ', str(action.kwargs)\n",
    "    \n",
    "    expr += [')']\n",
    "    print(\"\".join(expr))\n",
    "\n",
    "print(\"\\nOutputs:\")\n",
    "for out in sorted(training_plan.find_placeholders(\"output\"), key=tag_sort(\"output\")):\n",
    "    print(argToStr(out))\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": false
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Step 3: JIT Trace Training Plan\n",
    "\n",
    "Note: the Plan expects every argument to be a tensor, so scalars such as `lr` and `batch_size` are wrapped in tensors.\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
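  {
   "cell_type": "markdown",
   "source": [
    "Why scalars are wrapped in tensors: `th.jit.trace` records the operations applied to the tensors it receives as example inputs. The minimal sketch below uses a hypothetical `scaled_step` function (not part of this notebook) and assumes plain PyTorch without syft. It illustrates that a learning rate passed as a 0-dim tensor stays a real input of the traced function, so a different value can be supplied at call time.\n",
    "\n",
    "```python\n",
    "import torch as th\n",
    "\n",
    "def scaled_step(param, grad, lr):\n",
    "    # same update rule as naive_sgd, with lr as a positional tensor\n",
    "    return param - grad * lr\n",
    "\n",
    "# trace with example inputs; lr is a 0-dim tensor, not a Python float\n",
    "example_inputs = (th.ones(3), th.ones(3), th.tensor(0.1))\n",
    "traced_step = th.jit.trace(scaled_step, example_inputs)\n",
    "\n",
    "# call the traced function with a different learning rate\n",
    "out = traced_step(th.zeros(3), th.ones(3), th.tensor(0.5))\n",
    "print(out)\n",
    "```\n",
    "\n",
    "The result reflects the learning rate given at call time (0.5), not the one used while tracing (0.1)."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },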
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [
    {
     "name": "stdout",
     "text": [
      "def __call__(argument_0: Tensor,\n",
      "    argument_1: Tensor,\n",
      "    argument_2: Tensor,\n",
      "    argument_3: Tensor,\n",
      "    argument_4: Tensor,\n",
      "    argument_5: Tensor,\n",
      "    argument_6: Tensor,\n",
      "    argument_7: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:\n",
      "  _0 = torch.matmul(argument_0, torch.t(argument_4))\n",
      "  _1 = torch.add(_0, argument_5, alpha=1)\n",
      "  _2 = torch.relu(_1)\n",
      "  _3 = torch.add(torch.matmul(_2, torch.t(argument_6)), argument_7, alpha=1)\n",
      "  _4 = torch.softmax(_3, 1, None)\n",
      "  _5 = torch.mean(torch.mul(argument_1, torch.log(_4)), dtype=None)\n",
      "  _6 = torch.neg(_5)\n",
      "  _7 = torch.div(torch.sub(_4, argument_1, alpha=1), torch.mul(argument_2, CONSTANTS.c0))\n",
      "  _8 = torch.matmul(_7, argument_6)\n",
      "  _9 = torch.to(torch.gt(_1, 0), 6, False, False, None)\n",
      "  _10 = torch.mul(_8, _9)\n",
      "  _11 = torch.matmul(torch.t(_10), argument_0)\n",
      "  _12 = torch.sum(_10, [0], False, dtype=None)\n",
      "  _13 = torch.matmul(torch.t(_7), _2)\n",
      "  _14 = torch.sum(_7, [0], False, dtype=None)\n",
      "  _15 = torch.sub(argument_4, torch.mul(_11, argument_3), alpha=1)\n",
      "  _16 = torch.sub(argument_5, torch.mul(_12, argument_3), alpha=1)\n",
      "  _17 = torch.sub(argument_6, torch.mul(_13, argument_3), alpha=1)\n",
      "  _18 = torch.sub(argument_7, torch.mul(_14, argument_3), alpha=1)\n",
      "  _19 = torch.eq(torch.argmax(_4, 1, False), torch.argmax(argument_1, 1, False))\n",
      "  _20 = torch.sum(torch.to(_19, 6, False, False, None), dtype=None)\n",
      "  _21 = (_6, torch.div(_20, argument_2), _15, _16, _17, _18)\n",
      "  return _21\n",
      "\n"
     ],
     "output_type": "stream"
    }
   ],
   "source": [
    "X_trace = th.randn(2, 784)\n",
    "y_trace = th.randn(2, 10)\n",
    "lr = th.tensor(0.001)\n",
    "batch_size = th.tensor(32)\n",
    "\n",
    "# remove actual function, so plan executes recorded ops\n",
    "training_plan.forward = None\n",
    "\n",
    "# jit trace\n",
    "training_plan_torchscript = th.jit.trace(training_plan.__call__, (X_trace, y_trace, batch_size, lr, *model_params))\n",
    "\n",
    "# Let's see\n",
    "print(training_plan_torchscript.code)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": false
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Step 4: Serialize!\n",
    "\n",
    "Now it's time to serialize the model params and plans to protobuf and save them for later use:\n",
    " * In the \"Execute Plan\" notebook, we load and execute these plans and the model from Python.\n",
    " * In the \"Host Plan\" notebook, we send these plans and the model to PyGrid, so they can be executed from another worker (e.g. syft.js).\n",
    "\n",
    "**NOTE:**\n",
    " * We don't serialize the full Model, only its weights. How the Model itself is serialized is TBD.\n",
    "   `State` is a suitable protobuf class to wrap the list of Model param tensors.\n",
    " * The Plan containing the list of operations is serialized to `syft_proto.execution.v1.Plan`,\n",
    "   while the TorchScript Plan is serialized to `syft_proto.types.torch.v1.ScriptFunction`.\n",
    "   In the future they will converge to `syft_proto.execution.v1.Plan`, see https://github.com/OpenMined/PySyft/issues/2994#issuecomment-595333791"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [
    {
     "name": "stdout",
     "text": [
      "Writing Plan to e:\\ml/tp_ops.pb\n",
      "Writing ScriptFunction to e:\\ml/tp_ts.pb\n",
      "Writing State to e:\\ml/model_params.pb\n"
     ],
     "output_type": "stream"
    }
   ],
   "source": [
    "serializeToBinPb(hook.local_worker, training_plan, \"tp_ops.pb\")\n",
    "serializeToBinPb(hook.local_worker, training_plan_torchscript, \"tp_ts.pb\")\n",
    "\n",
    "# wrap weights in State to serialize\n",
    "model_params_state = State(\n",
    "    owner=hook.local_worker,\n",
    "    state_placeholders=[PlaceHolder().instantiate(param) for param in model_params]\n",
    ")\n",
    "\n",
    "serializeToBinPb(hook.local_worker, model_params_state, \"model_params.pb\")\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": false
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}